Code
from BayesForge import bf
import jax.numpy as jnp
m = bf(platform='cpu')
# Load and prepare data
df = m.load.phylo_simple()
L_df = m.load.phylo_L_simple()
L = L_df.values
species_to_idx = {sp: i for i, sp in enumerate(L_df.columns)}
df["phylo_idx"] = df["phylo"].map(species_to_idx)
m.data_on_model = {
"y": jnp.array(df["y"].values),
"x": jnp.array(df["x"].values),
"phylo_idx": jnp.array(df["phylo_idx"].values, dtype=jnp.int32),
"L": jnp.array(L)
}
def model(y, x, phylo_idx, L):
# Priors
intercept = m.dist.normal(0, 50, name="intercept")
b_x = m.dist.normal(0, 10, name="b_x")
# Standard deviation for phylogenetic effect
sd_phylo = m.dist.half_normal(20, name="sd_phylo")
sigma = m.dist.half_normal(20, name="sigma")
# Phylogenetic random effect (non-centered parameterization)
num_species = L.shape[0]
z_phylo = m.dist.normal(jnp.zeros(num_species), 1.0, name="z_phylo")
u_phylo = jnp.matmul(L, z_phylo) * sd_phylo
# Linear predictor
mu = intercept + b_x * x + u_phylo[phylo_idx]
# Likelihood
m.dist.normal(mu, sigma, name="obs", obs=y)
m.fit(model)/home/sosa/work/3.12venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning:
IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
bf v 0.0.48 package loaded
jax.local_device_count 32
0%| | 0/2000 [00:00<?, ?it/s]Compiling.. : 0%| | 0/2000 [00:00<?, ?it/s]
0%| | 0/2000 [00:00<?, ?it/s]
Compiling.. : 0%| | 0/2000 [00:00<?, ?it/s]
0%| | 0/2000 [00:00<?, ?it/s]
Compiling.. : 0%| | 0/2000 [00:00<?, ?it/s]
0%| | 0/2000 [00:00<?, ?it/s]
Compiling.. : 0%| | 0/2000 [00:00<?, ?it/s]Running chain 0: 0%| | 0/2000 [00:01<?, ?it/s]
Running chain 1: 0%| | 0/2000 [00:01<?, ?it/s]
Running chain 2: 0%| | 0/2000 [00:01<?, ?it/s]
Running chain 3: 0%| | 0/2000 [00:01<?, ?it/s]
Running chain 2: 5%|β | 100/2000 [00:01<00:02, 668.22it/s]
Running chain 3: 5%|β | 100/2000 [00:01<00:03, 624.14it/s]Running chain 0: 5%|β | 100/2000 [00:01<00:03, 574.53it/s]
Running chain 1: 5%|β | 100/2000 [00:01<00:03, 528.09it/s]
Running chain 2: 10%|β | 200/2000 [00:01<00:02, 820.01it/s]
Running chain 3: 10%|β | 200/2000 [00:01<00:02, 773.07it/s]Running chain 0: 10%|β | 200/2000 [00:01<00:02, 717.29it/s]
Running chain 1: 10%|β | 200/2000 [00:01<00:02, 690.72it/s]
Running chain 2: 15%|ββ | 300/2000 [00:01<00:01, 876.73it/s]
Running chain 3: 15%|ββ | 300/2000 [00:01<00:02, 828.25it/s]
Running chain 1: 15%|ββ | 300/2000 [00:02<00:02, 751.79it/s]
Running chain 2: 20%|ββ | 400/2000 [00:02<00:01, 919.12it/s]Running chain 0: 20%|ββ | 400/2000 [00:02<00:01, 886.15it/s]
Running chain 3: 20%|ββ | 400/2000 [00:02<00:01, 854.76it/s]
Running chain 1: 20%|ββ | 400/2000 [00:02<00:02, 786.21it/s]
Running chain 2: 25%|βββ | 500/2000 [00:02<00:01, 918.90it/s]
Running chain 3: 25%|βββ | 500/2000 [00:02<00:01, 899.60it/s]Running chain 0: 25%|βββ | 500/2000 [00:02<00:01, 883.67it/s]
Running chain 1: 25%|βββ | 500/2000 [00:02<00:01, 805.21it/s]
Running chain 2: 35%|ββββ | 700/2000 [00:02<00:01, 1024.33it/s]
Running chain 3: 35%|ββββ | 700/2000 [00:02<00:01, 1024.98it/s]Running chain 0: 35%|ββββ | 700/2000 [00:02<00:01, 974.37it/s]
Running chain 1: 35%|ββββ | 700/2000 [00:02<00:01, 934.10it/s]
Running chain 2: 45%|βββββ | 900/2000 [00:02<00:01, 1075.93it/s]
Running chain 3: 45%|βββββ | 900/2000 [00:02<00:01, 1091.41it/s]Running chain 0: 45%|βββββ | 900/2000 [00:02<00:01, 1043.35it/s]
Running chain 1: 45%|βββββ | 900/2000 [00:02<00:01, 1069.28it/s]Running chain 0: 50%|βββββ | 1000/2000 [00:02<00:00, 1008.98it/s]
Running chain 2: 55%|ββββββ | 1100/2000 [00:02<00:00, 1023.69it/s]
Running chain 3: 55%|ββββββ | 1100/2000 [00:02<00:00, 1041.89it/s]Running chain 0: 55%|ββββββ | 1100/2000 [00:02<00:00, 1006.41it/s]
Running chain 1: 55%|ββββββ | 1100/2000 [00:02<00:00, 997.50it/s]Running chain 0: 60%|ββββββ | 1200/2000 [00:02<00:00, 975.44it/s]
Running chain 1: 60%|ββββββ | 1200/2000 [00:02<00:00, 985.77it/s]
Running chain 2: 65%|βββββββ | 1300/2000 [00:02<00:00, 997.54it/s]
Running chain 3: 65%|βββββββ | 1300/2000 [00:02<00:00, 992.63it/s] Running chain 0: 65%|βββββββ | 1300/2000 [00:02<00:00, 973.10it/s]
Running chain 1: 65%|βββββββ | 1300/2000 [00:03<00:00, 982.13it/s]
Running chain 2: 70%|βββββββ | 1400/2000 [00:03<00:00, 988.85it/s]
Running chain 3: 70%|βββββββ | 1400/2000 [00:03<00:00, 979.90it/s]
Running chain 1: 70%|βββββββ | 1400/2000 [00:03<00:00, 964.28it/s]
Running chain 3: 75%|ββββββββ | 1500/2000 [00:03<00:00, 978.97it/s]Running chain 0: 75%|ββββββββ | 1500/2000 [00:03<00:00, 986.75it/s]
Running chain 1: 75%|ββββββββ | 1500/2000 [00:03<00:00, 968.55it/s]
Running chain 2: 80%|ββββββββ | 1600/2000 [00:03<00:00, 997.38it/s]
Running chain 3: 80%|ββββββββ | 1600/2000 [00:03<00:00, 975.58it/s]
Running chain 1: 80%|ββββββββ | 1600/2000 [00:03<00:00, 970.29it/s]
Running chain 2: 85%|βββββββββ | 1700/2000 [00:03<00:00, 995.01it/s]Running chain 0: 85%|βββββββββ | 1700/2000 [00:03<00:00, 990.69it/s]
Running chain 3: 90%|βββββββββ | 1800/2000 [00:03<00:00, 989.53it/s]Running chain 0: 90%|βββββββββ | 1800/2000 [00:03<00:00, 989.32it/s]
Running chain 1: 90%|βββββββββ | 1800/2000 [00:03<00:00, 1002.29it/s]
Running chain 2: 95%|ββββββββββ| 1900/2000 [00:03<00:00, 1012.20it/s]Running chain 0: 95%|ββββββββββ| 1900/2000 [00:03<00:00, 966.88it/s]Running chain 2: 100%|ββββββββββ| 2000/2000 [00:03<00:00, 549.75it/s]
Running chain 3: 100%|ββββββββββ| 2000/2000 [00:03<00:00, 996.90it/s]Running chain 3: 100%|ββββββββββ| 2000/2000 [00:03<00:00, 544.34it/s]
Running chain 0: 100%|ββββββββββ| 2000/2000 [00:03<00:00, 540.54it/s]
Running chain 1: 100%|ββββββββββ| 2000/2000 [00:03<00:00, 1016.62it/s]Running chain 1: 100%|ββββββββββ| 2000/2000 [00:03<00:00, 535.31it/s]